# helper functions for replication scripts
# Zhai & Garside 2022

if (!require("pacman")) install.packages("pacman")
pacman::p_load(tidyverse, fixest)

# R_01_main_text ----------------------------------------------------------

# Thin wrapper around `feols()`.
#
# df  : data handed to `feols()` as `data`.
# fml : model formula, or anything `as.formula()` can coerce (e.g. a string).
# fe  : forwarded to `feols()` as `fixef`.
# cl  : forwarded to `feols()` as `cluster`.
#
# Returns whatever `feols()` returns.
fit_twfe <- function(df, fml, fe, cl) {
  feols(fml = as.formula(fml), data = df, fixef = fe, cluster = cl)
}

# R_02_appendix_02_rand_checks --------------------------------------------

# Balance regression for a single covariate.
#
# Regresses `var_id` on factor(d) * factor(year) with `gen_id` fixed effects
# and standard errors clustered on `gen_id`.
#
# var_id : name of the outcome column (string).
# df     : data containing `var_id`, `d`, `year` and `gen_id`.
# d      : name of the treatment column (string), "flooded" by default.
#
# Returns the fitted model's coefficient table as a data frame, with the
# coefficient names moved into a `Term` column.
test_balance <- function(var_id, df, d = "flooded") {
  rhs <- paste0("factor(", d, ")*factor(year)")
  fit <- fixest::feols(
    fml = as.formula(paste0(var_id, "~", rhs)),
    data = df,
    fixef = "gen_id",
    cluster = "gen_id"
  )
  coef_tab <- as.data.frame(fit[["coeftable"]])
  rownames_to_column(coef_tab, var = "Term")
}

# Tidy the coefficient table of the distance balance test.
#
# fixest_fit : a fitted object exposing a `coeftable` element.
#
# Returns the coefficient table as a data frame with the coefficient names in
# a `Term` column, any intercept row removed, and a leading `covar` column set
# to "distance".
tidy_btest_distance <- function(fixest_fit) {
  fixest_fit[["coeftable"]] %>%
    as.data.frame() %>%
    rownames_to_column(var = "Term") %>%
    filter(!grepl("intercept", Term, ignore.case = TRUE)) %>%
    # `.before` takes a column index or name; 1 states directly that `covar`
    # goes first instead of passing a selection helper outside a selection
    # context.
    add_column(covar = "distance", .before = 1)
}

# Combine the per-covariate balance tests with the distance test.
#
# btest      : list of coefficient tables (data frames with a `Term` column),
#              one per covariate.
# btest_dist : data frame appended below the stacked covariate tests; its
#              columns must line up with the stacked result.
# var_id     : names to give the elements of `btest`; they end up in the
#              `covar` column.
#
# Only rows whose `Term` contains ":" (interaction terms) are kept from
# `btest` before `btest_dist` is appended.
tidy_best_all <- function(btest, btest_dist, var_id) {
  named_tests <- setNames(btest, var_id)
  stacked <- bind_rows(named_tests, .id = "covar")
  interactions <- filter(stacked, grepl(":", Term))
  rbind.data.frame(interactions, btest_dist)
}

# Fit `fml` with `feols()` using `gen_id` and `date` fixed effects and
# standard errors clustered on `GMD` and `Election`.
# (Presumably the pre-trend / parallel-trends check, going by the name.)
#
# df  : data containing the model variables plus `gen_id`, `date`, `GMD`
#       and `Election`.
# fml : model formula, or anything `as.formula()` can coerce (e.g. a string).
#       Coercion is a no-op for formula objects and matches what the other
#       fit_* wrappers in this file accept.
#
# Returns whatever `feols()` returns.
test_pt <- function(df, fml) {
  feols(
    fml = as.formula(fml),
    data = df,
    fixef = c("gen_id", "date"),
    cluster = c("GMD", "Election")
  )
}

# R_02_appendix_03_further_results ----------------------------------------

# Thin wrapper around `feols()` that also forwards observation weights.
#
# df  : data handed to `feols()` as `data`.
# fml : model formula, or anything `as.formula()` can coerce (e.g. a string).
# fe  : forwarded to `feols()` as `fixef`.
# cl  : forwarded to `feols()` as `cluster`.
# wt  : forwarded to `feols()` as `weights`.
#
# Returns whatever `feols()` returns.
fit_twfe2 <- function(df, fml, fe, cl, wt) {
  feols(
    fml = as.formula(fml), data = df,
    fixef = fe, cluster = cl, weights = wt
  )
}

# Thin wrapper around `feglm()`; the GLM counterpart of `fit_twfe2()`.
#
# df   : data handed to `feglm()` as `data`.
# fml  : model formula, or anything `as.formula()` can coerce (e.g. a string).
# fe   : forwarded to `feglm()` as `fixef`.
# cl   : forwarded to `feglm()` as `cluster`.
# wt   : forwarded to `feglm()` as `weights`.
# link : despite its name, this is forwarded as `feglm()`'s `family`
#        argument ("binomial" by default).
#
# Returns whatever `feglm()` returns.
fit_twfe2_logit <- function(df, fml, fe, cl, wt, link = "binomial") {
  feglm(
    fml = as.formula(fml), data = df,
    fixef = fe, cluster = cl, weights = wt,
    family = link
  )
}

# R_02_appendix_04_robust_checks ------------------------------------------

# Split `df` into one data frame per state.
#
# df           : data frame to split.
# state_id     : values of the state column to keep. Defaults to `IDs.state`,
#                which is not defined in this file and must exist in the
#                calling session when the default is used.
# state_id_var : unquoted name of the state column (`land_name` by default).
#
# Returns a list of data frames, as produced by `split()` on the state column
# of the filtered data.
make_state_dfs <- function(df, state_id = IDs.state, state_id_var = land_name) {
  state_id_var <- enquo(state_id_var)
  # `!!state_id` injects the argument's value, so a column of `df` that happens
  # to be called `state_id` cannot shadow it inside `filter()`.
  df_states <- filter(df, (!!state_id_var) %in% !!state_id)
  split(df_states, pull(df_states, !!state_id_var))
}

# Remove the `+ factor(incumbent)` term from each formula in `fml`.
#
# fml : list (or character vector) of formulas.
#
# Returns a list of formulas of the same length and with the same names as
# `fml`. Formulas without the term are returned unchanged in content.
drop_incumbent <- function(fml) {
  lapply(fml, function(fm) {
    # `format()` wraps long formulas over several lines; once collapsed, the
    # gap between "+" and "factor(incumbent)" can be more than one space, so
    # the pattern has to tolerate any amount of whitespace there.
    fm_text <- paste(format(fm), collapse = "")
    as.formula(gsub("\\+\\s*factor\\(incumbent\\)", "", fm_text))
  })
}

# Fit a seemingly-unrelated-regressions system with `systemfit()`.
#
# df  : data handed to `systemfit()` as `data`.
# fml : handed to `systemfit()` as `formula`.
#
# The call is namespace-qualified because this file does not attach
# systemfit itself; an unqualified call would fail with "could not find
# function" unless the calling script had attached the package.
fit_sur <- function(df, fml) {
  systemfit::systemfit(formula = fml, method = "SUR", data = df)
}
# Error-capturing version of `fit_sur()`: returns list(result, error), with
# `result` set to NA when the fit throws.
sfit_sur0 <- safely(fit_sur, otherwise = NA)

# Fit the SUR system, returning NA instead of stopping when the fit fails.
#
# df, fml : forwarded to `fit_sur()`.
#
# A failed fit still yields NA, but the captured error is reported as a
# warning so that failures are not mistaken for legitimately missing results.
sfit_sur <- function(df, fml) {
  attempt <- sfit_sur0(df = df, fml = fml)
  if (!is.null(attempt[["error"]])) {
    warning(
      "SUR fit failed, returning NA: ",
      conditionMessage(attempt[["error"]]),
      call. = FALSE
    )
  }
  attempt[["result"]]
}

# Pull the flood-effect estimates for one state out of the SUR fits.
#
# out      : list indexed as out[[i]][[state_id]]. Each entry is either a
#            fitted system exposing `$coefficients` and `$coefCov`, or a
#            length-one placeholder for a fit that is not available.
# state_id : name (or index) of the state to extract from each element.
# lhs      : outcome names used to label the selected coefficients; must have
#            as many entries as there are matching coefficients.
#
# For a real fit, the coefficients named like
# factor(flooded|flooded2)1:factor(date)2021-09-26 are kept, labelled with
# `lhs`, and returned with their standard errors (square root of the
# `coefCov` diagonal) and mu +/- 1.96 * se bounds. A placeholder yields a
# single all-NA row.
#
# Returns a list of data frames named "Primary" and "Secondary"; the
# hard-coded names assume `out` has two elements.
extract_sur_est <- function(out, state_id, lhs) {
  est <- lapply(seq_along(out), function(i) {
    fit <- out[[i]][[state_id]]

    # Length-one entry: no usable fit for this state.
    if (length(fit) == 1) {
      return(data.frame(
        "party" = NA, "mu" = NA, "se" = NA, "lwr" = NA, "upr" = NA
      ))
    }

    point_est <- fit$coefficients
    std_err <- sqrt(diag(fit$coefCov))

    list(point_est, std_err) %>%
      map(~ {
        .x[grepl("factor\\((flooded|flooded2)\\)1\\:factor\\(date\\)2021-09-26", names(.x))]
      }) %>%
      map(~ purrr::set_names(.x, lhs)) %>%
      purrr::set_names(c("mu", "se")) %>%
      bind_rows(.id = "measure") %>%
      # One row per outcome, with `mu` and `se` side by side.
      pivot_longer(cols = -measure, names_to = "party", values_to = "est") %>%
      pivot_wider(names_from = measure, values_from = est) %>%
      mutate(
        lwr = mu - 1.96 * se,
        upr = mu + 1.96 * se,
        party = case_when(
          party == "v_cducsu_pct" ~ "CDU/CSU",
          party == "v_spd_pct" ~ "SPD",
          party == "v_afd_pct" ~ "AfD",
          party == "v_fdp_pct" ~ "FDP",
          party == "v_dielinke_pct" ~ "Die Linke",
          party == "v_green_pct" ~ "Greens"
        )
      )
  })
  names(est) <- c("Primary", "Secondary")
  est
}

# Collect the SUR flood-effect estimates for several states into one table.
#
# out       : SUR fits, passed through to `extract_sur_est()`.
# state_ids : states to extract. Defaults to `IDs.state`, which is not defined
#             in this file and must exist in the calling session.
# lhs       : outcome names, passed through to `extract_sur_est()`.
#
# Returns a single data frame with `state` and `measure` identifier columns;
# rows containing any NA (e.g. the placeholder rows for unavailable fits) are
# dropped.
tidy_sur_est <- function(out, state_ids = IDs.state, lhs) {
  per_state <- map(state_ids, function(sid) {
    by_measure <- extract_sur_est(out = out, state_id = sid, lhs = lhs)
    bind_rows(by_measure, .id = "measure")
  })
  names(per_state) <- state_ids
  na.omit(bind_rows(per_state, .id = "state"))
}

# Point-range plot of the SUR flood-effect estimates.
#
# out, state_ids, lhs : passed through to `tidy_sur_est()`.
#
# Draws one point range (mu with lwr/upr) per party, with a dotted reference
# line at zero, faceted by measure (rows) and state (columns), and flipped so
# that parties run along the vertical axis. Returns the ggplot object.
plot_sur_est <- function(out, state_ids = IDs.state, lhs) {
  est <- tidy_sur_est(out = out, state_ids = state_ids, lhs = lhs)

  ggplot(est) +
    geom_pointrange(
      aes(
        x = party, y = mu, ymin = lwr, ymax = upr,
        color = party, shape = party
      )
    ) +
    facet_grid(measure ~ state) +
    geom_hline(yintercept = 0, color = "darkred", lty = "dotted") +
    coord_flip() +
    ggthemes::scale_color_stata() +
    labs(x = NULL, y = NULL, color = "Party", shape = "Party") +
    scale_y_continuous(breaks = scales::pretty_breaks(n = 5)) +
    theme_minimal(base_size = 14)
}

# Correlation between `turnout` and `v_green_pct` within each level of `D`.
#
# df      : data containing `turnout`, `v_green_pct` and the column `D`.
# D       : unquoted name of the grouping column (`flooded` by default).
# measure : optional label added as a `measure` column; nothing is added when
#           it is NULL.
#
# Returns one row per level of `D` holding the `broom::tidy()` output of
# `cor.test()`, with the grouping column (first column) renamed to `flooded`.
test_h3a <- function(df, D = flooded, measure = NULL) {
  D <- enquo(D)
  df %>%
    # Named form of nest(); the unnamed `nest(-col)` spelling is deprecated
    # and warns, while producing this same `data` list-column.
    nest(data = -!!D) %>%
    mutate(
      corr = map(data, ~ cor.test(.x$turnout, .x$v_green_pct)) %>%
        map(broom::tidy)
    ) %>%
    unnest(corr) %>%
    select(-data) %>%
    add_column(measure = measure) %>%
    rename(flooded = 1)
}

# R_03_simulation_analysis ------------------------------------------------

# Simulate a vector whose sample correlation with `x` is exactly `rho`.
#
# x   : numeric vector to correlate with.
# rho : target correlation, a single number strictly between -1 and 1.
#       At |rho| = 1 the construction divides by tan(acos(rho)) = 0 (or a
#       value numerically indistinguishable from it), and beyond that acos()
#       is undefined, so such values are rejected up front.
#
# Draws `length(x)` standard-normal values (so the result depends on the RNG
# state), removes their projection on centred `x`, scales both pieces to unit
# length and mixes them so that the angle between the result and centred `x`
# is acos(rho).
#
# Returns a numeric vector of the same length as `x`.
sim_corrv <- function(x, rho) {
  stopifnot(
    is.numeric(rho), length(rho) == 1L, !is.na(rho), abs(rho) < 1
  )
  n <- length(x)
  theta <- acos(rho)
  x1 <- rnorm(n)
  Xctr <- scale(cbind(x, x1), center = TRUE, scale = FALSE)
  # Orthonormal basis (a single column) for the span of centred x.
  Q <- qr.Q(qr(Xctr[, 1, drop = FALSE]))
  # Residual of centred x1 after projecting on centred x. Computed as
  # v - Q (Q'v) so that no n x n projection matrix has to be formed.
  x1o <- Xctr[, 2] - Q %*% crossprod(Q, Xctr[, 2])
  Xc1 <- cbind(Xctr[, 1], x1o)
  # Scale both columns to unit length.
  Y <- Xc1 %*% diag(1 / sqrt(colSums(Xc1^2)))
  Y[, 2] + (1 / tan(theta)) * Y[, 1]
}

# Randomly keep a share of the 1s in `x`.
#
# x : vector in which entries equal to 1 are candidates.
# p : success probability handed to `rbinom()`.
#
# Every position where `x` equals 1 receives an independent Bernoulli(p)
# draw; every other position is 0. Returns a numeric 0/1 vector of the same
# length as `x`.
sim_dplus <- function(x, p) {
  candidates <- which(x == 1)
  draws <- numeric(length(x))
  draws[candidates] <- rbinom(length(candidates), size = 1, prob = p)
  draws
}

# Period-by-period weighted lag of `x`: within each period, w %*% x.
#
# x : numeric vector, one value per row of the panel.
# w : square weights matrix with one row/column per observation in a period;
#     rows within a period must be in the order `w` expects.
# t : period identifier, same length as `x`.
#
# Returns a numeric vector aligned with `x` (element i is the lag for row i).
# For input sorted by `t` this equals the period-by-period concatenation;
# writing results back by row index keeps the output aligned with the input
# when rows are not sorted by period. Rows with a missing `t` get NA, and
# empty input gives an empty numeric vector.
sim_slag <- function(x, w, t) {
  stopifnot(length(x) == length(t))
  periods <- unique(t[!is.na(t)])
  rows_by_period <- split(seq_along(t), match(t, periods))
  lagged <- rep(NA_real_, length(x))
  for (rows in rows_by_period) {
    lag_t <- as.vector(w %*% x[rows])
    if (length(lag_t) != length(rows)) {
      stop("`w` must be square, with one row per observation in each period",
           call. = FALSE)
    }
    lagged[rows] <- lag_t
  }
  lagged
}

# Coefficient table with confidence intervals for a fitted model.
#
# m          : fitted model whose `summary()` and `confint()` methods accept a
#              `cluster` argument (presumably a fixest model; confirm against
#              callers).
# cluster_id : clustering variable, "units" by default.
#
# Returns the `broom::tidy()` output of the clustered summary with `conf.low`
# and `conf.high` columns appended.
tidy_model <- function(m, cluster_id = "units") {
  m_summary <- broom::tidy(summary(m, cluster = cluster_id))
  # Request the same clustering for the intervals as for the standard errors,
  # so both columns are based on the same variance estimate.
  m_ci <- confint(m, cluster = cluster_id)
  rownames(m_ci) <- NULL
  colnames(m_ci) <- c("conf.low", "conf.high")
  cbind.data.frame(m_summary, m_ci)
}
